{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4zW6CU8F69Zp"
   },
   "source": [
    "## Keras example of supervised learning on a NeuroGym task\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xIE7qx_a7D0e"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/example_neurogym_keras.ipynb)\n",
    "\n",
    "NeuroGym is a toolkit for training network models on many established neuroscience tasks. The tasks are implemented as reinforcement learning (Gym) environments, but they can also be used for supervised learning, as this notebook does. It includes working memory tasks, value-based decision tasks, and context-dependent perceptual categorization tasks.\n",
    "\n",
    "In this notebook, we first show how to install the toolbox.\n",
    "\n",
    "We then build a supervised dataset from a task and read the observation and action sizes from its environment.\n",
    "\n",
    "Finally, we train an LSTM network on the Random Dots Motion task (`PerceptualDecisionMaking-v0`) using standard supervised learning (with Keras), evaluate it, and plot the results.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DS1FhYPG38WO"
   },
   "source": [
    "### Installation on Google Colab\n",
    "\n",
    "Uncomment and execute the cell below if running on Google Colab.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 561
    },
    "colab_type": "code",
    "id": "Q-4MfhnP4AgL",
    "outputId": "609a4b19-e5c5-4ff9-ca31-52b1b40f230d"
   },
   "outputs": [],
   "source": [
    "# %tensorflow_version 1.x\n",
    "# # Install gymnasium\n",
    "# ! pip install gymnasium\n",
    "# # Install neurogym\n",
    "# ! git clone https://github.com/neurogym/neurogym.git\n",
    "# %cd neurogym/\n",
    "# ! pip install -e ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OCFMPbzX38Wj"
   },
   "source": [
    "### Task, network, and training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 496
    },
    "colab_type": "code",
    "id": "jAxTPbzL38Wl",
    "outputId": "a7a65c14-5e6b-40ab-9f85-2ef8db4355e1"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/gryang/Dropbox/Code/MyPython/neurogym/neurogym/core.py:253: UserWarning: Warning: Time for period fixation 100.000000  lasts only one timestep. Agents will not have time to respond (e.g. make a choice) on time.\n",
      "  ' time to respond (e.g. make a choice) on time.')\n",
      "/Users/gryang/Dropbox/Code/MyPython/neurogym/neurogym/core.py:253: UserWarning: Warning: Time for period decision 100.000000  lasts only one timestep. Agents will not have time to respond (e.g. make a choice) on time.\n",
      "  ' time to respond (e.g. make a choice) on time.')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, None, 3)]         0         \n",
      "_________________________________________________________________\n",
      "lstm (LSTM)                  (None, None, 64)          17408     \n",
      "_________________________________________________________________\n",
      "time_distributed (TimeDistri (None, None, 3)           195       \n",
      "=================================================================\n",
      "Total params: 17,603\n",
      "Trainable params: 17,603\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "   7/2000 [..............................] - ETA: 5:35 - loss: 1.0581 - acc: 0.4196"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000/2000 [==============================] - 106s 53ms/step - loss: 0.0569 - acc: 0.9835\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import neurogym as ngym\n",
    "\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.layers import Dense, LSTM, TimeDistributed, Input\n",
    "\n",
    "# Show each distinct warning once per location\n",
    "# (an earlier 'ignore' filter was overridden by this one, so it is dropped)\n",
    "warnings.filterwarnings('default')\n",
    "\n",
    "# Environment\n",
    "task = 'PerceptualDecisionMaking-v0'\n",
    "kwargs = {'dt': 100}  # simulation time step in ms\n",
    "seq_len = 100  # number of time steps per training sequence\n",
    "\n",
    "# Make supervised dataset\n",
    "dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16,\n",
    "                       seq_len=seq_len)\n",
    "env = dataset.env\n",
    "obs_size = env.observation_space.shape[0]\n",
    "act_size = env.action_space.n\n",
    "\n",
    "# Model\n",
    "num_h = 64\n",
    "# from https://www.tensorflow.org/guide/keras/rnn\n",
    "xin = Input(batch_shape=(None, None, obs_size), dtype='float32')\n",
    "# The dataset yields time-major arrays (seq_len, batch, features),\n",
    "# hence time_major=True\n",
    "seq = LSTM(num_h, return_sequences=True, time_major=True)(xin)\n",
    "# One softmax over actions at every time step\n",
    "mlp = TimeDistributed(Dense(act_size, activation='softmax'))(seq)\n",
    "model = Model(inputs=xin, outputs=mlp)\n",
    "model.summary()\n",
    "model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# Train network: each call to dataset() returns one (inputs, targets) batch\n",
    "steps_per_epoch = 2000\n",
    "data_generator = (dataset() for _ in range(steps_per_epoch))\n",
    "history = model.fit(data_generator, steps_per_epoch=steps_per_epoch)"
   ]
  },
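  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the model summary can be checked by hand. The cell below is a minimal sketch, assuming the standard LSTM parameterization (four gates, each with an input kernel, a recurrent kernel, and a bias) and a dense readout with a bias. The helper names `lstm_params` and `dense_params` and the variables `n_obs`, `n_hidden`, and `n_act` are introduced here for illustration only; their values are read off the summary printed above (3 inputs, 64 hidden units, 3 outputs).\n",
    "\n",
    "If these assumptions hold, the three printed numbers should match the summary: 17408, 195, and 17603.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand check of the parameter counts (sizes taken from the printed summary)\n",
    "n_obs, n_hidden, n_act = 3, 64, 3\n",
    "\n",
    "\n",
    "def lstm_params(n_in, n_units):\n",
    "    # 4 gates, each with an input kernel, a recurrent kernel and a bias\n",
    "    return 4 * (n_in * n_units + n_units * n_units + n_units)\n",
    "\n",
    "\n",
    "def dense_params(n_in, n_out):\n",
    "    # Weight matrix plus one bias per output unit\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "\n",
    "n_lstm = lstm_params(n_obs, n_hidden)\n",
    "n_dense = dense_params(n_hidden, n_act)\n",
    "print(n_lstm, n_dense, n_lstm + n_dense)"
   ]
  },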
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analysis\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 122
    },
    "colab_type": "code",
    "id": "iaeFA3oklR99",
    "outputId": "f9f027df-f67e-41a6-d5d3-dd2d98555d6e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.895\n"
     ]
    }
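   ],
   "source": [
    "# Evaluate the trained network on freshly sampled trials\n",
    "perf = 0\n",
    "num_trial = 200\n",
    "for i in range(num_trial):\n",
    "    env.new_trial()\n",
    "    obs, gt = env.obs, env.gt\n",
    "    obs = obs[:, np.newaxis, :]  # add a batch axis: (time, 1, obs_size)\n",
    "\n",
    "    action_pred = model.predict(obs)\n",
    "    action_pred = np.argmax(action_pred, axis=-1)  # most likely action per step\n",
    "    # A trial counts as correct if the action at the last time step matches the target\n",
    "    perf += gt[-1] == action_pred[-1, 0]\n",
    "\n",
    "perf /= num_trial\n",
    "print(perf)"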
   ]
  },
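  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation above scores a trial only at its last time step. The following cell is a small, self-contained sketch of that indexing on made-up numbers, assuming the same time-major layout `(time, batch, actions)` used in this notebook. The array `probs` and the labels `gt_toy` are illustrative values, not outputs of the trained model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy softmax outputs for one trial: 4 time steps, batch of 1, 3 actions\n",
    "probs = np.array([\n",
    "    [[0.8, 0.1, 0.1]],\n",
    "    [[0.7, 0.2, 0.1]],\n",
    "    [[0.6, 0.3, 0.1]],\n",
    "    [[0.1, 0.2, 0.7]],\n",
    "])\n",
    "gt_toy = np.array([0, 0, 0, 2])  # hypothetical target action per time step\n",
    "\n",
    "actions = np.argmax(probs, axis=-1)  # shape (4, 1): chosen action per step\n",
    "correct = gt_toy[-1] == actions[-1, 0]  # compare only the final time step\n",
    "print(actions[:, 0], correct)"
   ]
  },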
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 225
    },
    "colab_type": "code",
    "id": "TdpkVOLhuRAK",
    "outputId": "e1abdc87-eaff-4e32-d395-1af3513a80ff"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAADQCAYAAAA53LuNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU9bn48c+TnSQsYV8CBBRQtiA7iKDggkpxp1iqovXyokqv9r4ulV5bvfK79uLS2t66iwq1NqKooAVXUKgbEhBQFtkaJOxhh5CQ5fn9cU7SyUzWYWbOJDzv12temTnnfM/3mTPfeXLme875HlFVjDHGRF6M1wEYY8zZyhKwMcZ4xBKwMcZ4xBKwMcZ4xBKwMcZ4JM7rAOoqQRI1iZQ6lytuVfcycSdL6lwGoDg1NqhypcEVI+5U3c9kKY2T4CoL8l92SVKQZ9uU1j3OhOPB1RXsNpEgqtNgN39xcO+tuFFwFQbTJhOPlgZVV2FacI0rpiioYsQfDyLOk6eCqus4h/NUtZX/9HqXgJNIYYiMqXO5vBuG1blMq+xjdS4DsPfCpkGVOx1cMVp+W1znMvmtgsv2wX6Rj/ase4wAMfl1/1KmfxJcAjjVIrhtEldQ96RYFOR2TNkX3HY8kBkfVLnTzer+3jLeDS5Jbb8hKahySfuDS9zpS47XvdDX3wZV18c6f0dl060LwhhjPGIJ2BhjPGIJ2BhjPGIJ2BhjPGIJ2BhjPOJ5AhaRFBGJcZ93F5HxIhLcIVtjjKlHPE/AwHIgSUQ6AB8CtwBzPI3IGGMiIBoSsKhqPnA98LSq3gT08jgmY4wJu6hIwCIyDJgELHKnBXlNmDHG1B/RkIDvAX4NvK2q60WkK/CJxzEZY0zYeX4psqoux+kHLnu9Hfh37yIyxpjI8DwBi0h34D+BDHziUdXRXsVkjDGR4HkCBt4AngVmA8ENP2aMMfVQNCTgYlV9xusgjDEm0qLhINy7InKXiLQTkeZlD6+DMsaYcIuGPeDb3L/TfaYp0NWDWIwxJmI8T8Cq2sXrGIwxxgueJ2B33IefAyPdSZ8Cz6lqkDcaMcaY+sHzBAw8A8QDT7uvb3Gn3elZRMYYEwHRkIAHqWqmz+ulIrLWs2iMMSZCouEsiBIROafshXspsp0PbIxp8KJhD3g68ImIbAcE6Azc7m1IxhgTfp4nYFVdIiLdgB7upO9VtdDLmIwxJhI8S8AiMlpVl4rI9X6zzhURVPWtUNYXW1D3MltuaRxcZaWlQRWLOyVBlds5tu5l2i/VoOoquvpIUOUSvmsWVLnT7et+MszOK4IczTTIz63lOYfqXObggSZB1SWfJgRVLr9zcVDlUnLqniJyRycHVZcEuf2LU4Nry1snptS5zLlfB1VVlbzcAx4FLAV+VMk8BUKagI0xJtp4loBV9UH36UxV/afvPBGxizOMMQ1eNJwF8WYl0+ZHPApjjIkwL/uAz8O591tTv37gJkCSN1EZY0zkeNkH3AMYBzSjYj/wceDfPInIGGMiyMs+4IXAQhEZpqpfehWHMcZ4xfPzgIFvRORunO6I8q4HVb3Du5CMMSb8ouEg3CtAW+AKYBmQjtMNYYwxDVo0JOBzVfW3wElVnQtcDQzxOCZjjAm7aEjAZZc6HRGR3kBToLWH8RhjTEREQx/w8yKSBvwWeAdIdZ8bY0yDFg0J+GVVLcHp/7X7wBljzhrR0AXxTxF5XkTGiEhwo9EYY0w9FA0J+DzgY+BuIEdEnhSRER7HZIwxYed5AlbVfFV9XVWvB/rhXIq8zOOwjDEm7DxPwAAiMkpEngZW4VyMMcHjkIwxJuw8PwgnIjnAN8DrwHRVPeltRMYYExmeJmARiQVeUtWZXsZhjDFe8LQLwj39bJyXMRhjjFc874IAPheRJ4F5QHn3g6qu9i4kY4wJv2hIwP3cv77dEAqM9iAWY4yJGM8TsKpe4nUMxhjj
Bc9PQxORNiLyooi8577uKSI/8zouY4wJN88TMDAH+ABo777eDNzrWTTGGBMhnndBAC1V9XUR+TWAqhaLSElVCzfuWcpF8wrqXMmIlOfqXGb6xhvrXAbgnLS8oMo1jisMqlz3lL11LnP9lWuDquv9k+cHVe6viYODKte5yeE6l3mty9Kg6vqfvPOCKjepaXady/xu7xVB1bWhbZugyv3pnPeDKvfUD3U/FBMbUxpUXX/u+npQ5S5dek9Q5S7ssa3OZQ4EVVPVomEP+KSItMA58IaIDAWOehuSMcaEXzTsAf8HzjjA54jI50ArILhdT2OMqUc8T8CqulpERuHcpl6A71W1qIZixhhT73neBSEiNwGNVHU9cC0wT0T6exyWMcaEnecJGPitqh53xwAeA7wIPONxTMYYE3bRkIDLzni4GnhBVRcBCR7GY4wxERENCXiXiDwH/BhYLCKJREdcxhgTVtGQ6CbgXIhxhaoeAZoD070NyRhjws/zBKyq+UAOcKWI/AJop6ofehuVMcaEn+cJWEQeAOYCLYCWwMsi8htvozLGmPDz/DxgYBKQqaoFACIyC1gD/I+nURljTJh5vgcM7Ma5EWeZRGCXR7EYY0zEeLYHLCJ/xhn/4SiwXkQ+cmddCnztVVzGGBMpXnZBlA0htQFYgpOMi4FPPIvIGGMiyMsE/DfgYeAOYAfOOBCdgJeB//IwLmOMiQgv+4AfBdKALqo6QFX7A12BpsBjHsZljDER4WUCHgdMUdXjZRNU9Rjwc5zLko0xpkHzMgGrqmolE0twB2c3xpiGzMsEvEFEbvWfKCI/BTZ5EI8xxkSUlwfh7gbeEpE7gFXutIFAI+A6z6IyxpgI8SwBq+ouYIiIjAZ6uZMXq+oSr2IyxphI8vxSZFVdCgR3G1tjjKnHouFSZGOMOStZAjbGGI9IJWeCRTUROYBz5VxlWgJ5EQynKhZHoGiJxeIIFC2xNOQ4OqtqK/+J9S4BV0dEslV1oMURXXFA9MRicQSKlljOxjisC8IYYzxiCdgYYzzS0BLw814H4LI4AkVLLBZHoGiJ5ayLo0H1ARtjTH3S0PaAjTGm3rAEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHgnbHTFE5CWcW8/vV9XelcwX4E/AVUA+MFlVV9e03pYtW2pGRkaIozXGmDOz71gBbZokVTpv1apVeZUNRxnOWxLNAZ4E/lLF/CuBbu5jCPCM+7daGRkZZGdnhyhEY4wJjYwZi8iedXWl80Sk0jHMw9YFoarLgUPVLHIN8Bd1fAU0E5F24YrHGGPCobS0lFF/mI9S93F1vLwpZwdgp8/rXHfaHv8FRWQKMAWgU6dOEQnOGGNq8sRHm/n90i/YlXgHLWL/nYwZAsA9Y7rxy8u611je87si14aqPo87RNzAgQNt+DZjTFT45WXdObd1KndnzQMgp4ouiKp4mYB3AR19Xqe70+qsqKiI3NxcCgoKQhKYOTNJSUmkp6cTHx/vdSjGhN23u46SFJvK6ZLSOpf1MgG/A0wTkddwDr4dVdWA7ofayM3NpXHjxmRkZOCcXGG8oqocPHiQ3NxcunTp4nU4xoTdKxseplFae36eeWedy4btIJyIZAFfAj1EJFdEfiYiU0VkqrvIYmA7sBV4Abgr2LoKCgpo0aKFJd8oICK0aNHCfo2Ys0JxSQmbT7xNXNKOWvX5+gvbHrCq3lzDfAXuDlV9lnyjh30W5mzx6fZ1lEo+g9sPCqq8XQkXQgsWLEBE2LRpU7XLzZkzh927d5e/vvPOO9mwYUO4wzPGhNjiTZ8DcEX34UGVP6sT8BMfbQ7p+rKyshgxYgRZWVnVLuefgGfPnk3Pnj1DGosxJvy+yl2JaAJjewwMqvxZnYD/tGRLyNZ14sQJPvvsM1588UVee+218umPPPIIffr0ITMzkxkzZjB//nyys7OZNGkS/fr149SpU1x88cXlV/dlZWXR
p08fevfuzX333Ve+ntTUVO6//34yMzMZOnQo+/btA+CNN96gd+/eZGZmMnLkyJC9H2NMzfJOFNAqoR9J8QlBla8X5wHXxUPvrmfD7mO1Xv7Hz31Z4zI92zfhwR/1qnaZhQsXMnbsWLp3706LFi1YtWoV+/fvZ+HChaxYsYLk5GQOHTpE8+bNefLJJ3n88ccZOLDif83du3dz3333sWrVKtLS0rj88stZsGAB1157LSdPnmTo0KE8/PDD/OpXv+KFF17gN7/5DTNnzuSDDz6gQ4cOHDlypNbv2xhzZkpKlbjjt/PzgR1rXrgKZ90ecO7hfFb88xAr/ulcJV32PPdw/hmtNysri4kTJwIwceJEsrKy+Pjjj7n99ttJTk4GoHnz5tWuY+XKlVx88cW0atWKuLg4Jk2axPLlywFISEhg3LhxAAwYMICcnBwALrzwQiZPnswLL7xASUnJGb0HY0ztbT9wgvzTJfTp0DTodTS4PeCa9lR9ZcxYVOcrVypz6NAhli5dyrfffouIUFJSgohw0003nfG6y8THx5efXRAbG0txcTEAzz77LCtWrGDRokUMGDCAVatW0aJFi5DVa4yp3FMr5rA78Y+0a/5B0Os46/aAw2H+/Pnccsst7Nixg5ycHHbu3EmXLl1o2rQpL7/8Mvn5zt71oUPOXnfjxo05fvx4wHoGDx7MsmXLyMvLo6SkhKysLEaNGlVt3du2bWPIkCHMnDmTVq1asXPnzmqXN8aExpe5X1EsuxnYMSPodZzVCfieMd1Csp6srCyuu+66CtNuuOEG9uzZw/jx4xk4cCD9+vXj8ccfB2Dy5MlMnTq1/CBcmXbt2jFr1iwuueQSMjMzGTBgANdcc021dU+fPr38oN3w4cPJzMwMyXsyxlRv6+F1tEjoQUJc8B0J4lwPUX8MHDhQ/ccD3rhxI+eff75HEZnK2GdiGrKCotMkP9yYoW0m8sXP59a4vIisUtWAc9XO6j1gY4wJxvvfZ6NymqEdg7sCrowlYGOMqaOt+/NJLr6I8edVf4ymJpaAjTGmjo4ca0PXmPsZ2TXgdpd1UqcELCJpItL3jGo0xph6LvuHHfTu0JSYmDMbeKrGBCwin4pIExFpDqwGXhCRP5xRrcYYU0+dKCzgw0PXsV9eOeN11WYPuKmqHgOux7mJ5hDg0jOu2Rhj6qHFm1agUszADn3OeF21ScBx7t2KJwB/P+MaG6h9+/bxk5/8hK5duzJgwACGDRvG22+/HfE4MjIyyMvLC5j+u9/9Lqj1LViwoMJQmb4DBxlzNvpwizN+zI/Ou+iM11WbBDwT+ADYqqorRaQrELphxBoAVeXaa69l5MiRbN++nVWrVvHaa6+Rm5sbsGzZJcSRVlUCVlVKS6u+l5V/AjbmbJe9O5tYUhmecebnudeYgFX1DVXtq6p3ua+3q+oNZ1xzA7J06VISEhKYOnVq+bTOnTvzi1/8AnDG/x0/fjyjR49mzJgxqCrTp0+nd+/e9OnTh3nznDuqfvrpp+UD7gBMmzaNOXPmAM6e7YMPPkj//v3p06dP+aDvBw8e5PLLL6dXr17ceeedVHZhzYwZMzh16hT9+vVj0qRJ5OTk0KNHD2699VZ69+7Nzp07SU1NLV9+/vz5TJ48mS+++IJ33nmH6dOn069fP7Zt2wY4Q2AOHjyY7t27849//CO0G9OYKPfPY+tondiTmJgzP4msxmvoRKQV8G9Ahu/yqnrHGdceJhfPuThg2oReE7hr0F3kF+Vz1atXBcyf3G8yk/tNJi8/jxtfv7HCvE8nf1ptfevXr6d///7VLrN69WrWrVtH8+bNefPNN1mzZg1r164lLy+PQYMG1Wos35YtW7J69WqefvppHn/8cWbPns1DDz3EiBEjeOCBB1i0aBEvvvhiQLlZs2bx5JNPsmbNGgBycnLYsmULc+fOZejQoVXWN3z4cMaPH8+4ceO48cZ/bZPi4mK+/vprFi9ezEMPPcTHH39cY+zGNAQFRSUk
FVzHFb0yQrK+2qTwhUBT4GNgkc/DVOHuu+8mMzOTQYP+dZXMZZddVj4c5WeffcbNN99MbGwsbdq0YdSoUaxcubLG9V5//fVAxeEoly9fzk9/+lMArr76atLS0moVY+fOnatNvnWNw5izwaa9x2lUPJIf97k2JOurzSgSyap6X82LRY/q9liT45Ornd8yuWWNe7z+evXqxZtvvln++qmnniIvL6/CgOspKSk1ricuLq5Cf6z/nYUTExOBisNRBss/Ht8badZ0R+NQxmFMffL+ppWclm307nBJSNZXmz3gv4tI4G92U2706NEUFBTwzDPPlE8rG4KyMhdddBHz5s2jpKSEAwcOsHz5cgYPHkznzp3ZsGEDhYWFHDlyhCVLltRY98iRI/nb3/4GwHvvvcfhw4crXS4+Pp6ioqIq19OmTRs2btxIaWlphbM3qho605iz0asbnuJA0gO0b5oUkvXVJgHfg5OEC0TkuPuo/T1/zgIiwoIFC1i2bBldunRh8ODB3HbbbTzyyCOVLn/dddfRt29fMjMzGT16NI8++iht27alY8eOTJgwgd69ezNhwgQuuOCCGut+8MEHWb58Ob169eKtt96iU6dOlS43ZcoU+vbty6RJkyqdP2vWLMaNG8fw4cNp165d+fSJEyfy2GOPccEFF5QfhDPmbJVz7FvaJIXmABzYcJQmTOwzMQ3NwZPHaflYMy5Ln8qHdz5Vp7JnNByliIwXkcfdx7iaS5SXGysi34vIVhGZUcn8ySJyQETWuI87a7tuY4yJpIUbPgcpZUTGkJCtszanoc0CBgGvupPuEZELVfXXNZSLBZ4CLgNygZUi8o6q+p/VP09Vp9U9dGOMiZyPtzpXwF3Tc0TI1lmbsyCuAvqpaimAiMwFvgGqTcDAYJyr57a75V4DrgHssipjTL3TTC+lR0wqme27hmydte1JbubzvLb3YO4A+N4hMted5u8GEVknIvNFpGMt1x2gvvVlN2T2WZiGaPPeUkZ0ujik66xNAv5f4BsRmePu/a4CHg5R/e8CGaraF/gIqPTmSiIyRUSyRST7wIEDAfOTkpI4ePCgffGjgKpy8OBBkpJCc5qOMdFgz7HDrDo0m7bNj4R0vTV2Qahqloh8itMPDHCfqu6txbp3Ab57tOnuNN91H/R5ORt4tIoYngeeB+csCP/56enp5ObmUllyNpGXlJREenq612EYEzIL13/Gkfi/kpJyZUjXW2UCFpHzVHWTiJQNclA2tFd7EWmvqqtrWPdKoJuIdMFJvBOBn/jV0U5V97gvxwMb6/wOcC4y6NKlSzBFjTGmRku3fwXA+BAegIPq94D/A5gC/L6SeQqMrm7FqlosItNwhrKMBV5S1fUiMhPIVtV3gH8XkfFAMXAImFz3t2CMMeG1dt9qEmhFr7aVX+gUrCoTsKpOcZ9eqaoVBgcQkVp18KnqYmCx37QHfJ7/mprPpjDGGE/tPPEd7ZN7hXy9tTkI90UtpxljTIOz++gRCkoP0Kd19UPOBqO6PuC2OKeNNRKRC4Cy4bKaAMkhj8QYY6JQzoFiOha8zt2DQn9D+Or6gK/A6ZNNx+kHLkvAx4D/CnkkxhgThdbtOooQy+DO7UO+7ur6gOcCc0XkBlV9s6rljDGmIZu99n/RxiWkpVwd8nXXpg94gIiUXwknImki8j8hj8QYY6LQd0cWE58UeIPdUKhNAr5SVcsv/1DVwzjjQxhjTIO2LW8PhbqHvmE4AAe1S8CxIpJY9kJEGgGJ1SxvjDENwoL1zl2/L+kS3P0Ta1Kb0dBeBZaIyMs4B+ImU8WYDcYY05Asy1kBwPheob0CrkxtxoJ4RETWApfiXAH3AdA5LNEYY0wU2XfsNM1i+tE5rVVY1l/b4Sj34STfm3AuQQ5qzAZjjKlPYk9cyy3nvhS29Vd3IUZ34Gb3kQfMw7mHXGjux2yMMVFs/7ECdh8toG96bYdAr7vq9oA34eztjlPVEar6Z6AkbJEYY0wUeTH7LXIT
f0Zq6r6w1VFdAr4e2AN8IiIviMgY/nU1nDHGNGjLc1ZQIvsZdc55YaujygSsqgtUdSJwHvAJcC/QWkSeEZHLwxaRMcZEgQ1535Ac24l2TdLCVkeNB+FU9aSq/k1Vf4QzLsQ3wH1hi8gYY6LAnlPr6ZzaO6x11PYsCMC5Ck5Vn1fVMeEKyBhjvLZ293aKOES/tgPCWk+dErAxxpwNvtt9mNTiK/nReeHd17QEbIwxfvYdSqVV8d1c2+vCsNZjCdgYY/x8uWMz3Von0yghNqz1WAI2xhgfpaWlLNh9C3tjngx7XZaAjTHGR3buVoo5Sr+2/cJelyVgY4zx8e5GZwjKy7oND3tdloCNMcbH5z98DRrLuPPDMwawL0vAxhjjY9PBNTSJ60qzRilhr6s2A7IbY8xZQVVpVHgdwzsnR6S+sO4Bi8hYEfleRLaKyIxK5ieKyDx3/goRyTiT+p74aHPEykWyrmDL1YcYgy1nMXpbrqHGmHv4FCWn+nD9+dfWuWwwwpaARSQWeAq4EugJ3CwiPf0W+xlwWFXPBZ4AHjmTOv+0ZEvEykWyrmDL1YcYgy1nMXpbrqHGuGhjNgUx6+jZPvzdDxDeLojBwFZV3Q4gIq8B1wAbfJa5Bvhv9/l84EkREVXVula2dNMuTsZ8zv3v76kwvXOT7nRs3I2C4nyy930SUK5rU+d/wvzV3/PN/n8EzO/WrC9tUjpyrPAQ6/K+LJ9+MmY797+/h/PSLqBlcnsOFxxg/cGvA8r3bDGI5kmtyTu1h02HVpeXK9O35TCaJDZnX34uWw6vDSh/QeuLSIlvQpHs5v73ZwfMH9jmEpLiktl5fAs7jlX8j38yZjvvrOlFfGwiOUc3kXtiW0D5Ye3HEiuxbDvyHXtO7qjw3gThwg7ODbA3H17L/vyKt+aOkziGtr8CgE2HVnEy5osK7y0xthGD2o4GYH3eCg4X5lUonxyXSv82owCYtXQ+x08fqTC/cUIzMls5VyKt3reM/OITFd7bY58U0bvlEABW7l1KYcmpCuVbJLXh/BYDAVix50OKSosqbP/WjTrQvblzqtEXu96jlNIK5duldOacZr0p1dJK21Z66jlkND2PotLTrNjzUcC2LZJi3vt2T7Vtr31qF04WHavQ9spirKrtlfFve/5ty7/t+Stre8Wyv9K2Vdb2dp/4J9uPbqgw72TMdt7+5rwq2x7AkLaXBrQ93xgra3tl/Nue//b3b3t5p/ZWKJ8Y2whoxnvf7qmx7a098Hl525v33WvsS1jGua1/GfB+wkGCyHW1W7HIjcBYVb3TfX0LMERVp/ks8527TK77epu7TJ7fuqYAUwA6deo0YMeOf31YT3y0mT8t2UIJx8ltdHNAHM2KbqVp8QSKZT+7ku4ImJ92egpNSsZzWnLYkzQtYH6L0/eQWnIZhTEb2Zs4PWB+y8IZpJSO4FTMavYnPhAwv3XhQzQqHUB+zBccSPxdwPw2hY+RVHo+J2KXcDDhiYD57Qr+TIJ24VjsuxxOeC5gfvuC2cRrW47Gvc6R+L8EzE8/9SqxNOVw3FyOxb8RML/TqbcR4jkU/xzH496tOFPj6FywAIC8+D9yMu7jCrNjNJWOBa8BcCDhd+THflFhfmxpK9ILXwZgX8JvKYj9psL8+NJOtC98GoC9Cf9JYeymCvMTSnvQrvD3AOxOnEZRTE6F+UklF9Dm9P8DIDfxZ5TEVBw4u1HJMFqfvh+AnUk/oVSOVZifUjyalkX/AcCOpGtBiivMb1x8Nc2Lfo5SxA+NrsNfk6IbSSuebG2vgbW9xJJetD3t/Bi/Z0w3fnlZ94DY60pEVqnqwIDp9SEB+xo4cKBmZ2cHTN+y/yijnvgrz/60f4XpLZJb06JRK06XnCbnSOBPklbJbbn5uQ0snDaIHUcD9xBbp7SjWVJz8otOknssp3z61L+u5tmf9qddajqNE5ty8vRxdh3/IaB8h8adSElozPHCo+w5kVterkx6kwyS41M4WniYfSd2
B5Tv1LQrSXGNuOyPi3jqloyA+RnNupEQm8DBUwc4mL+/wrypf13Nh7+4mbiYOA7k7+PwqcDNem7z84mRGPaf3MORgkMV3hsidG/u/ELYe2IXxwor7qHGxMRybpozWPWu4z9wy0vLKry3uJh4uqY5jTf3WA75RScrlI+PTaRLs3MZ+8d/8PztHSgorrgHmxTXiE5NuwKQc3Qbp4sLKry3v9x+EelNnG2y/fBmikuLKpRPSWhMh8adANh6eBOlpSUVtn/jxKa0S00HYPOhDeD3XWialEablPaUaimX/t+rAW0rrVFLWiW3obi0mO2Hvw/Ytne/sp2P7v1RtW0vrVELCosLKrS9shirantl/Nuef9vyb3v+ytreZX98j6du6RQwv6ztHS44xIGTFff+p/51Ne9P+3GVbQ+ga1qPgLbnG2Nlba+cX9v7yYtLK7w3/7Z38vTxCsXjYuK5a+4+3r/3omrbHsAPR7dXaHvTXsnlh1k3BbyfM1FVAkZVw/IAhgEf+Lz+NfBrv2U+AIa5z+Nw7j0n1a13wIABWpXO9/29ynnVCaZcJOsKtlx9iDHYchajt+UsxroBsrWSfBbOsyBWAt1EpIuIJAATgXf8lnkHuM19fiOw1A02KPeM6RaxcpGsK9hy9SHGYMtZjN6WsxhDI2xdEAAichXwRyAWeElVHxaRmTj/Dd4RkSTgFeAC4BAwUd2DdtWs8wCwo4rZLXH2or1mcQSKllgsjkDREktDjqOzqrbynxjWBBxpIpKtlfWzWByei5ZYLI5A0RLL2RiHXYpsjDEesQRsjDEeaWgJ+HmvA3BZHIGiJRaLI1C0xHLWxdGg+oCNMaY+aWh7wMYYU29YAjbGGI/UywQc6WEuq4iho4h8IiIbRGS9iNxTyTIXi8hREVnjPgIv2A9NLDki8q1bR8B12uL4P3d7rBOR/pWtJwRx9PB5r2tE5JiI3Ou3TFi2iYi8JCL73cvby6Y1F5GPRGSL+zetirK3uctsEZHbKlvmDON4TEQ2udv+bRFpVkXZaj/HEMXy3yKyy2f7X1VF2Wq/YyGIY55PDDkisqaKsiHbJlV9Z71oJ+Uquzwumh84F3VsA7oCCcBaoKffMncBz7rPJwLzwhBHO6C/+7wxsLmSOK4HMr4AAAWMSURBVC4G/h6BbZIDtKxm/lXAe4AAQ4EVEfqc9uKcgB72bQKMBPoD3/lMexSY4T6fATxSSbnmwHb3b5r7PC3EcVwOxLnPH6ksjtp8jiGK5b+B/6zFZ1ftd+xM4/Cb/3vggXBvk6q+s160k7JHfdwDLh/mUlVPA2XDXPq6BpjrPp8PjBERCWUQqrpHVVe7z48DG4EOoawjhK4B/qKOr4BmItIuzHWOAbapalVXLYaUqi7HuZrSl287mAtUNsr2FcBHqnpIVQ8DHwFjQxmHqn6oqmVDrX0FpAe7/jONpZZq8x0LSRzu93ICkBXs+usQR1Xf2Yi3kzL1MQF3AHb6vM4lMPGVL+M2/KNAi3AF5HZxXACsqGT2MBFZKyLviUivMIWgwIciskqcoTv91WabhdpEqv5SRWKbALRR1bJhvPYCbSpZJtLb5g6cXyOVqelzDJVpbnfIS1X83I7kNrkI2KeqVY2eHpZt4ved9ayd1McEHFVEJBV4E7hXVY/5zV6N8xM8E/gzsCBMYYxQ1f44dx+5W0RGhqmeWhFn8KXxQOAgsJHbJhWo8zvS03MuReR+oBh4tYpFIvE5PgOcA/QD9uD8/PfSzVS/9xvybVLddzbS7aQ+JuBdQEef1+nutEqXEZE4oClwMNSBiEg8zgf5qqq+5T9fVY+p6gn3+WIgXkRahjoOVd3l/t0PvI3zE9JXbbZZKF0JrFbVff4zIrVNXPvKulrcv4GD1kZo24jIZGAcMMn9kgeoxed4xlR1n6qWqGop8EIVdURqm8QB1wPzqlom1Nukiu+sZ+2kPibgiA9zWRm37+pFYKOq/qGKZdqW9T2LyGCc7R3SfwQikiIijcue
4xzw+c5vsXeAW8UxFDjq85MrHKrcq4nENvHh2w5uAxZWsswHwOUikub+HL/cnRYyIjIW+BUwXlXzq1imNp9jKGLx7fu/roo6avMdC4VLgU3q3pDBX6i3STXfWe/aSSiOLkb6gXNUfzPOkdr73WkzcRo4QBLOz9+twNdA1zDEMALnp8o6YI37uAqYCkx1l5kGrMc5ivwVMDwMcXR117/Wratse/jGITg3SN0GfAsMDONnk4KTUJv6TAv7NsFJ+HuAIpz+uZ/h9PsvAbYAHwPN3WUHArN9yt7htpWtwO1hiGMrTv9hWTspO0OnPbC4us8xDLG84raBdTiJp51/LFV9x0IZhzt9Tlm78Fk2bNukmu9sxNtJ2cMuRTbGGI/Uxy4IY4xpECwBG2OMRywBG2OMRywBG2OMRywBG2OMRywBmwZLRO53R71a546mNURE7hWRZK9jMwbsjhimgRKRYcAfgItVtdC92i4B+ALnPOhouP25OcvZHrBpqNoBeapaCOAm3BtxTvT/REQ+ARCRy0XkSxFZLSJvuOMElI1D+6g7Fu3XInKuO/0mEfnOHUxouTdvzTQUtgdsGiQ3kX4GJONc3TRPVZeJSA7uHrC7V/wWcKWqnhSR+4BEVZ3pLveCqj4sIrcCE1R1nIh8C4xV1V0i0kxVj3jyBk2DYHvApkFSZ8CfAcAU4AAwzx0Qx9dQnAG5Pxfnjgy3AZ195mf5/B3mPv8cmCMi/4YzcLkxQYvzOgBjwkVVS4BPgU/dPVf/28gIziDbN1e1Cv/nqjpVRIYAVwOrRGSAqoZrMCHTwNkesGmQxLk/XTefSf2AHcBxnNvRgDMY0IU+/bspItLdp8yPff5+6S5zjqquUNUHcPasfYcoNKZObA/YNFSpwJ/FuQFmMc4IVlNwhsp8X0R2q+olbrdElogkuuV+gzMKGECaiKwDCt1yAI+5iV1wRtBaG5F3YxokOwhnTCV8D9Z5HYtpuKwLwhhjPGJ7wMYY4xHbAzbGGI9YAjbGGI9YAjbGGI9YAjbGGI9YAjbGGI/8f3kkhjfgm2xdAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 360x216 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the last evaluated trial; index the batch axis (axis 1) to keep all time steps\n",
    "_ = ngym.utils.plotting.fig_(obs[:, 0], action_pred[:, 0], gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OfGvOBO-84l-"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Copy of example_NeuroGym_keras.ipynb",
   "provenance": []
  },
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
